{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brandonbrown/anaconda3/envs/deeprl/lib/python3.6/site-packages/matplotlib/font_manager.py:278: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.\n",
      "  'Matplotlib is building the font cache using fc-list. '\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch as th\n",
    "from torch.autograd import Variable\n",
    "from matplotlib import pyplot as plt\n",
    "import random\n",
    "import sys\n",
    "sys.path.append(\"/Users/brandonbrown/Desktop/Projects/pycolab/\")\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def randPair(s,e):\n",
    "    return np.random.randint(s,e), np.random.randint(s,e)\n",
    "\n",
    "class BoardPiece:\n",
    "    \n",
    "    def __init__(self, name, code, pos):\n",
    "        self.name = name #name of the piece\n",
    "        self.code = code #an ASCII character to display on the board\n",
    "        self.pos = pos #2-tuple e.g. (1,4)\n",
    "        \n",
    "class BoardMask:\n",
    "    \n",
    "    def __init__(self, name, mask, code):\n",
    "        self.name = name\n",
    "        self.mask = mask\n",
    "        self.code = code\n",
    "        \n",
    "    def get_positions(self): #returns tuple of arrays\n",
    "        return np.nonzero(self.mask)\n",
    "\n",
    "def zip_positions2d(positions): #positions is tuple of two arrays\n",
    "    x,y = positions\n",
    "    return list(zip(x,y))\n",
    "\n",
    "class GridBoard:\n",
    "    \n",
    "    def __init__(self, size=4):\n",
    "        self.size = size #Board dimensions, e.g. 4 x 4\n",
    "        self.components = {} #name : board piece\n",
    "        self.masks = {}\n",
    "    \n",
    "    def addPiece(self, name, code, pos=(0,0)):\n",
    "        newPiece = BoardPiece(name, code, pos)\n",
    "        self.components[name] = newPiece\n",
    "        \n",
    "    #basically a set of boundary elements\n",
    "    def addMask(self, name, mask, code):\n",
    "        #mask is a 2D-numpy array with 1s where the boundary elements are\n",
    "        newMask = BoardMask(name, mask, code)\n",
    "        self.masks[name] = newMask\n",
    "    \n",
    "    def movePiece(self, name, pos):\n",
    "        move = True\n",
    "        for _, mask in self.masks.items():\n",
    "            if pos in zip_positions2d(mask.get_positions()):\n",
    "                move = False\n",
    "        if move:\n",
    "            self.components[name].pos = pos\n",
    "    \n",
    "    def delPiece(self, name):\n",
    "        del self.components['name']\n",
    "    \n",
    "    def render(self):\n",
    "        dtype = '<U2'\n",
    "        displ_board = np.zeros((self.size, self.size), dtype=dtype)\n",
    "        displ_board[:] = ' '\n",
    "        \n",
    "        for name, piece in self.components.items():\n",
    "            displ_board[piece.pos] = piece.code\n",
    "            \n",
    "        for name, mask in self.masks.items():\n",
    "            displ_board[mask.get_positions()] = mask.code\n",
    "        \n",
    "        return displ_board\n",
    "    \n",
    "    def render_np(self):\n",
    "        num_pieces = len(self.components) + len(self.masks)\n",
    "        displ_board = np.zeros((num_pieces, self.size, self.size), dtype=np.uint8)\n",
    "        layer = 0\n",
    "        for name, piece in self.components.items():\n",
    "            pos = (layer,) + piece.pos\n",
    "            displ_board[pos] = 1\n",
    "            layer += 1\n",
    "            \n",
    "        for name, mask in self.masks.items():\n",
    "            x,y = game.board.masks['boundary'].get_positions()\n",
    "            z = np.repeat(layer,len(x))\n",
    "            a = (z,x,y)\n",
    "            displ_board[a] = 1\n",
    "            layer += 1\n",
    "        return displ_board"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def addTuple(a,b):\n",
    "    return tuple([sum(x) for x in zip(a,b)])\n",
    "        \n",
    "class Gridworld:\n",
    "    \n",
    "    def __init__(self, size=4, mode='static'):\n",
    "        if size >= 4:\n",
    "            self.board = GridBoard(size=size)\n",
    "        else:\n",
    "            print(\"Minimum board size is 4. Initialized to size 4.\")\n",
    "            self.board = GridBoard(size=4)\n",
    "        \n",
    "        #Add pieces, positions will be updated later\n",
    "        self.board.addPiece('Player','P',(0,0))\n",
    "        self.board.addPiece('Goal','+',(1,0))\n",
    "        self.board.addPiece('Pit','-',(2,0))\n",
    "        self.board.addPiece('Wall','W',(3,0))\n",
    "            \n",
    "        if mode == 'static':\n",
    "            self.initGridStatic()\n",
    "        elif mode == 'player':\n",
    "            self.initGridPlayer()\n",
    "        else:\n",
    "            self.initGridRand()\n",
    "    \n",
    "    #Initialize stationary grid, all items are placed deterministically\n",
    "    def initGridStatic(self):\n",
    "        #Setup static pieces\n",
    "        self.board.components['Player'].pos = (0,3) #Row, Column\n",
    "        self.board.components['Goal'].pos = (0,0)\n",
    "        self.board.components['Pit'].pos = (0,1)\n",
    "        self.board.components['Wall'].pos = (1,1)\n",
    "    \n",
    "    #Check if board is initialized appropriately (no overlapping pieces)\n",
    "    def validateBoard(self):\n",
    "        all_positions = [piece.pos for name,piece in self.board.components.items()]\n",
    "        if len(all_positions) > len(set(all_positions)):\n",
    "            return False\n",
    "        else:\n",
    "            return True\n",
    "\n",
    "    #Initialize player in random location, but keep wall, goal and pit stationary\n",
    "    def initGridPlayer(self):\n",
    "        #height x width x depth (number of pieces)\n",
    "        self.initGridStatic()\n",
    "        #place player\n",
    "        self.board.components['Player'].pos = randPair(0,self.board.size)\n",
    "\n",
    "        if (not self.validateBoard()):\n",
    "            #print('Invalid grid. Rebuilding..')\n",
    "            self.initGridPlayer()\n",
    "\n",
    "    #Initialize grid so that goal, pit, wall, player are all randomly placed\n",
    "    def initGridRand(self):\n",
    "        #height x width x depth (number of pieces)\n",
    "        self.board.components['Player'].pos = randPair(0,self.board.size)\n",
    "        self.board.components['Goal'].pos = randPair(0,self.board.size)\n",
    "        self.board.components['Pit'].pos = randPair(0,self.board.size)\n",
    "        self.board.components['Wall'].pos = randPair(0,self.board.size)\n",
    "\n",
    "        if (not self.validateBoard()):\n",
    "            #print('Invalid grid. Rebuilding..')\n",
    "            self.initGridRand()\n",
    "\n",
    "    def makeMove(self, action):\n",
    "        #need to determine what object (if any) is in the new grid spot the player is moving to\n",
    "        #actions in {u,d,l,r}\n",
    "        def checkMove(addpos=(0,0)):\n",
    "            new_pos = addTuple(self.board.components['Player'].pos, addpos)\n",
    "            if new_pos == self.board.components['Wall'].pos:\n",
    "                pass #block move, player can't move to wall\n",
    "            elif max(new_pos) > (self.board.size-1):    #if outside bounds of board\n",
    "                pass\n",
    "            elif min(new_pos) < 0: #if outside bounds\n",
    "                pass\n",
    "            else:\n",
    "                self.board.movePiece('Player', new_pos)\n",
    "        if action == 'u': #up\n",
    "            checkMove((-1,0))\n",
    "        elif action == 'd': #down\n",
    "            checkMove((1,0))\n",
    "        elif action == 'l': #left\n",
    "            checkMove((0,-1))\n",
    "        elif action == 'r': #right\n",
    "            checkMove((0,1))\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "    def reward(self):\n",
    "        if (self.board.components['Player'].pos == self.board.components['Pit'].pos):\n",
    "            return -10\n",
    "        elif (self.board.components['Player'].pos == self.board.components['Goal'].pos):\n",
    "            return 10\n",
    "        else:\n",
    "            return -1\n",
    "\n",
    "    def display(self):\n",
    "        return self.board.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['+', '-', ' ', 'P'],\n",
       "       [' ', 'W', ' ', ' '],\n",
       "       [' ', ' ', ' ', ' '],\n",
       "       [' ', ' ', ' ', ' ']], dtype='<U2')"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game = Gridworld(size=4, mode='static')\n",
    "game.display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['+', '-', ' ', ' '],\n",
       "       [' ', 'W', ' ', ' '],\n",
       "       [' ', ' ', 'P', ' '],\n",
       "       [' ', ' ', ' ', ' ']], dtype='<U2')"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.makeMove('d')\n",
    "game.makeMove('d')\n",
    "game.makeMove('l')\n",
    "game.display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([['+', '-', ' ', ' '],\n",
       "       [' ', 'W', ' ', ' '],\n",
       "       [' ', ' ', ' ', ' '],\n",
       "       [' ', ' ', 'P', ' ']], dtype='<U2')"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.makeMove('d') # (0,3) + (-1,3)\n",
    "print(game.reward())\n",
    "game.display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 1, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[1, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[0, 1, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[0, 0, 0, 0],\n",
       "        [0, 1, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]]], dtype=uint8)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.board.render_np()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['+', '-', ' ', ' '],\n",
       "       [' ', 'W', ' ', 'P'],\n",
       "       [' ', ' ', ' ', '#'],\n",
       "       ['#', '#', '#', '#']], dtype='<U2')"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask1 = np.array([[0, 0, 0, 0],\n",
    "        [0, 0, 0, 0],\n",
    "        [0, 0, 0, 1],\n",
    "        [1, 1, 1, 1]])\n",
    "game.board.addMask('boundary', mask1, '#')\n",
    "game.display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([['+', '-', ' ', ' '],\n",
       "       [' ', 'W', ' ', 'P'],\n",
       "       [' ', ' ', ' ', '#'],\n",
       "       ['#', '#', '#', '#']], dtype='<U2')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.makeMove('d') # (0,3) + (-1,3)\n",
    "print(game.reward())\n",
    "game.display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[0, 0, 0, 0],\n",
       "        [0, 0, 0, 1],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[1, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[0, 1, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[0, 0, 0, 0],\n",
       "        [0, 1, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 1],\n",
       "        [1, 1, 1, 1]]], dtype=uint8)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.board.render_np()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Sokoban:\n",
    "    \n",
    "    def __init__(self, size=5, mode='static'):\n",
    "        if size >= 5:\n",
    "            self.board = GridBoard(size=size)\n",
    "        else:\n",
    "            print(\"Minimum board size is 5. Initialized to size 5.\")\n",
    "            self.board = GridBoard(size=5)\n",
    "        \n",
    "        #Add pieces, positions will be updated later\n",
    "        self.board.addPiece('Player','P',(0,0))\n",
    "        self.board.addPiece('Goal','+',(0,4))\n",
    "        self.board.addPiece('Box','B',(2,1))\n",
    "        \n",
    "        mask = np.array([\n",
    "            [0, 0, 1, 0, 0],\n",
    "            [0, 0, 1, 0, 0],\n",
    "            [1, 0, 1, 1, 0],\n",
    "            [1, 0, 0, 0, 0],\n",
    "            [1, 0, 1, 1, 0]])\n",
    "        self.board.addMask('boundary', mask, code='#')\n",
    "            \n",
    "        if mode != 'static':\n",
    "            self.initGridRand()\n",
    "    \n",
    "    def initGridRand(self):\n",
    "        self.board.components['Player'].pos = (np.random.randint(0,2), np.random.randint(0,2))\n",
    "        self.board.components['Goal'].pos = (np.random.randint(0,2), np.random.randint(3,5))\n",
    "        \n",
    "    def makeMove(self, action):\n",
    "        #need to determine what object (if any) is in the new grid spot the player is moving to\n",
    "        #actions in {u,d,l,r}\n",
    "        def checkMove(addpos=(0,0)):\n",
    "            new_pos = addTuple(self.board.components['Player'].pos, addpos)\n",
    "            if max(new_pos) > (self.board.size-1):    #if outside bounds of board\n",
    "                pass\n",
    "            elif min(new_pos) < 0: #if outside bounds\n",
    "                pass\n",
    "            else:\n",
    "                self.board.movePiece('Player', new_pos)\n",
    "        if action == 'u': #up\n",
    "            checkMove((-1,0))\n",
    "        elif action == 'd': #down\n",
    "            checkMove((1,0))\n",
    "        elif action == 'l': #left\n",
    "            checkMove((0,-1))\n",
    "        elif action == 'r': #right\n",
    "            checkMove((0,1))\n",
    "        else:\n",
    "            pass\n",
    "        \n",
    "    def getReward(self):\n",
    "        if (self.board.components['Player'].pos == self.board.components['Goal'].pos):\n",
    "            return 10\n",
    "        else:\n",
    "            return -1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:deeprl]",
   "language": "python",
   "name": "conda-env-deeprl-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
